{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ebcf3daa-78ce-4baa-846a-ea5e298c9de5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install -q datasets trl peft bitsandbytes sentencepiece wandb huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4d54dddd-dfda-481a-95d4-79d1525524c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9bddf0a36c64d3f9c6d4a58c3d59282",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d137e202-a20e-4c46-9285-3c755e5c3955",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mahm-rimer\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: WANDB_PROJECT=hindi_dpo_test\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import gc\n",
    "import torch\n",
    "\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig\n",
    "from datasets import load_dataset\n",
    "from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training\n",
    "from trl import DPOTrainer\n",
    "import bitsandbytes as bnb\n",
    "import wandb\n",
    "wandb.login()\n",
    "%env WANDB_PROJECT=hindi_dpo_test\n",
    "\n",
    "model_name = \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\"\n",
    "new_model = \"Hindi-SentenceRetrieval-Tinyllama-1.1B\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f0ac7684-68cd-4891-9979-6d1e22bb229b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chatml_format(example):\n",
    "    # Format system\n",
    "    if len(example['task']) > 0:\n",
    "        message = {\"role\": \"system\", \"content\": example['task']}\n",
    "        system = tokenizer.apply_chat_template([message], tokenize=False)\n",
    "    else:\n",
    "        system = \"\"\n",
    "\n",
    "    # Format instruction\n",
    "    message = {\"role\": \"user\", \"content\": example['query']}\n",
    "    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "    # Format chosen answer\n",
    "    chosen = example['pos'] + \"<|im_end|>\\n\"\n",
    "\n",
    "    # Format rejected answer\n",
    "    rejected = example['neg'] + \"<|im_end|>\\n\"\n",
    "\n",
    "    return {\n",
    "        \"prompt\": system + prompt,\n",
    "        \"chosen\": chosen,\n",
    "        \"rejected\": rejected,\n",
    "    }\n",
    "\n",
    "# Load dataset\n",
    "dataset = load_dataset(\"TokenBender/e5_FT_sentence_retrieval_task_Hindi_mini\")['train']\n",
    "\n",
    "# Save columns\n",
    "original_columns = dataset.column_names\n",
    "\n",
    "# Tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"left\"\n",
    "\n",
    "# Format dataset\n",
    "dataset = dataset.map(\n",
    "    chatml_format,\n",
    "    remove_columns=original_columns\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fb9be1e7-9677-4dbd-8a75-946f39e232dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['prompt', 'chosen', 'rejected'],\n",
       "    num_rows: 6633\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7282026e-3a22-4d57-b39b-761b3f9cf7b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'prompt': '<|system|>\\nप्रश्न के रूप में एक सेलिब्रिटी का नाम दिए जाने पर, साक्षात्कार और जीवनी पुनर्प्राप्त करें।</s>\\n<|user|>\\nलियोनार्डो डिकैप्रियो एक अमेरिकी अभिनेता, फिल्म निर्माता और पर्यावरण कार्यकर्ता हैं। उन्हें छह अकादमी पुरस्कारों, चार ब्रिटिश अकादमी फिल्म पुरस्कारों और नौ स्क्रीन एक्टर्स गिल्ड पुरस्कारों के लिए नामांकित किया गया है, जिनमें से प्रत्येक पुरस्कार में से एक और ग्यारह नामांकनों में से तीन गोल्डन ग्लोब पुरस्कार जीते हैं। डिकैप्रियो ने 1980 के दशक के अंत में टेलीविजन विज्ञापनों में दिखाई देकर अपने करियर की शुरुआत की। इसके बाद उन्होंने विभिन्न टेलीविजन श्रृंखलाओं में आवर्ती भूमिकाएँ निभाईं, जैसे कि सोप ओपेरा सांता बारबरा और सिटकॉम ग्रोइंग पेन्स। उन्होंने रॉबर्ट डी नीरो के साथ संस्मरण दिस बॉयज़ लाइफ के फिल्म रूपांतरण में अभिनय करने से पहले क्रिटर्स 3 में जोश के रूप में अभिनय करके अपने फिल्म करियर की शुरुआत की। डिकैप्रियो को नाटक वॉट्स ईटिंग गिल्बर्ट ग्रेप (1993) में उनकी सहायक भूमिका के लिए सराहा गया था, और जेम्स कैमरून के महाकाव्य रोमांस टाइटैनिक (1997) के साथ अंतर्राष्ट्रीय प्रसिद्धि प्राप्त करने से पहले, द बास्केटबॉल डायरीज़ (1995) और रोमांटिक ड्रामा रोमियो + जूलियट (1996) नाटक में प्रमुख भूमिकाओं के साथ सार्वजनिक पहचान प्राप्त की, जो उस समय तक की सबसे अधिक कमाई करने वाली फिल्म बन गई। 2000 के दशक से, डिकैप्रियो को फिल्म शैलियों की एक विस्तृत श्रृंखला में उनके काम के लिए आलोचनात्मक प्रशंसा मिली है। उनकी बाद की फिल्मों में द मैन इन द आयरन मास्क (1998), जीवनी अपराध नाटक कैच मी इफ यू कैन (2002) और महाकाव्य ऐतिहासिक नाटक गैंग्स ऑफ न्यूयॉर्क (2002) शामिल हैं, जो निर्देशक मार्टिन स्कॉर्सेज़ के साथ उनके कई सहयोगों में से पहली थी। डिकैप्रियो को राजनीतिक युद्ध थ्रिलर ब्लड डायमंड (2006), नियो-नॉयर क्राइम ड्रामा द डिपार्टेड (2006), जासूसी थ्रिलर बॉडी ऑफ लाइज़ (2008), ड्रामा रिवोल्यूशनरी रोड (2008), मनोवैज्ञानिक थ्रिलर शटर आइलैंड (2010), विज्ञान कथा थ्रिलर इंसेप्शन (2010), जीवनी फिल्म जे. एडगर (2011), पश्चिमी जांगो अनचेन्ड (2012) और पीरियड ड्रामा द ग्रेट गैट्सबी (2013) में उनके प्रदर्शन के लिए सराहा गया था। द एविएटर (2004) में हॉवर्ड ह्यूजेस और द रेवनेंट (2015) में ह्यूग ग्लास के डिकैप्रियो के चित्रण ने उन्हें सर्वश्रेष्ठ अभिनेता-मोशन पिक्चर ड्रामा के लिए गोल्डन ग्लोब पुरस्कार दिलाया। द वुल्फ ऑफ वॉल स्ट्रीट (2013) में जॉर्डन बेल्फोर्ट के रूप में उनके प्रदर्शन ने उन्हें सर्वश्रेष्ठ अभिनेता-मोशन पिक्चर म्यूजिकल या कॉमेडी के लिए गोल्डन ग्लोब पुरस्कार दिलाया। उन्होंने द रेवनेंट के लिए सर्वश्रेष्ठ अभिनेता का अपना पहला अकादमी पुरस्कार भी जीता। डिकैप्रियो एपियन वे प्रोडक्शंस के संस्थापक हैं-एक निर्माण कंपनी जिसने उनकी कुछ फिल्मों और वृत्तचित्र श्रृंखला ग्रीन्सबर्ग (2008-2010) का निर्माण किया है-और पर्यावरण जागरूकता को बढ़ावा देने के लिए समर्पित एक गैर-लाभकारी संगठन लियोनार्डो डिकैप्रियो फाउंडेशन। वह एक प्रतिबद्ध पर्यावरण कार्यकर्ता भी हैं। 2016 में, उन्हें जलवायु परिवर्तन के लिए संयुक्त राष्ट्र शांति दूत नामित किया गया था और जलवायु परिवर्तन से निपटने के लिए उनके काम के लिए विश्व आर्थिक मंच में क्रिस्टल पुरस्कार प्राप्त हुआ था। 2018 में, जल भृंग की एक प्रजाति का नाम उनकी पर्यावरणीय सक्रियता की मान्यता में उनके नाम पर रखा गया था, ग्रौवेलिनस लियोनार्डोडिकाप्रियोई। डिकैप्रियो के पास लॉस फेलिज़, लॉस एंजिल्स में एक घर और बैटरी पार्क सिटी, लोअर मैनहट्टन में एक अपार्टमेंट है। वह जलवायु परिवर्तन पर कार्रवाई के लिए एक मुखर अधिवक्ता रहे हैं और उन्हें पर्यावरण समूहों से प्रशंसा मिली है। जलवायु परिवर्तन के खतरों के बारे में जागरूकता बढ़ाने के उनके प्रयासों और अक्षय ऊर्जा में उनके निवेश के लिए उनकी विशेष रूप से प्रशंसा की गई है। डिकैप्रियो 1992 से शाकाहारी रहे हैं और पर्यावरणवाद और मानवीय कारणों के प्रति अपने समर्पण के लिए जाने जाते हैं। वह पर्यावरण के मुद्दों के बारे में जागरूकता बढ़ाने के अभियानों में शामिल रहे हैं और टीएजी ह्यूअर के ब्रांड एंबेसडर के रूप में कार्य किया है। वह द 11थ आवर जैसे वृत्तचित्रों के निर्माण में भी शामिल रहे हैं। 2010 में, उन्होंने भूकंप के बाद हैती में राहत प्रयासों के लिए $1 मिलियन का दान दिया। 2017 में, उन्होंने तूफान हार्वे के पीड़ितों को $1 मिलियन का दान दिया। 2018 में, उन्होंने गंभीर रूप से लुप्तप्राय सुमात्रा हाथी के संरक्षण और संरक्षण का समर्थन करने के लिए लियोनार्डो डिकैप्रियो फाउंडेशन को $30 लाख का दान दिया। 2020 में, उन्होंने गंभीर रूप से लुप्तप्राय सुमात्रा हाथी के संरक्षण और संरक्षण का समर्थन करने के लिए लियोनार्डो डिकैप्रियो फाउंडेशन को $30 लाख का दान दिया। 2021 में, उन्होंने गंभीर रूप से लुप्तप्राय सुमात्रा हाथी के संरक्षण और संरक्षण का समर्थन करने के लिए लियोनार्डो डिकैप्रियो फाउंडेशन को $43 मिलियन का दान दिया।</s>\\n<|assistant|>\\n', 'chosen': 'लियोनार्डो डिकैप्रियो एक प्रतिभाशाली अभिनेता हैं जिन्हें टाइटैनिक और द रेवनेंट जैसी फिल्मों में उनकी भूमिकाओं के लिए जाना जाता है। ओ पर्यावरण सक्रियतामे सेहो शामिल रहल छथि आ अपन प्रदर्शनक लेल कतेको पुरस्कार प्राप्त कयने छथि। डिकैप्रियो जलवायु परिवर्तन पर कार्रवाई के लिए एक मुखर अधिवक्ता रहे हैं और उन्हें पर्यावरण समूहों से प्रशंसा मिली है। जलवायु परिवर्तन के खतरों के बारे में जागरूकता बढ़ाने के उनके प्रयासों और अक्षय ऊर्जा में उनके निवेश के लिए उनकी विशेष रूप से प्रशंसा की गई है। डिकैप्रियो 1992 से शाकाहारी रहे हैं और पर्यावरणवाद और मानवीय कारणों के प्रति अपने समर्पण के लिए जाने जाते हैं। वह पर्यावरण के मुद्दों के बारे में जागरूकता बढ़ाने के अभियानों में शामिल रहे हैं और टीएजी ह्यूअर के ब्रांड एंबेसडर के रूप में कार्य किया है। वह द 11थ आवर जैसे वृत्तचित्रों के निर्माण में भी शामिल रहे हैं। 2010 में, उन्होंने भूकंप के बाद हैती में राहत प्रयासों के लिए $1 मिलियन का दान दिया। 2017 में, उन्होंने तूफान हार्वे के पीड़ितों को $1 मिलियन का दान दिया। 2018 में, उन्होंने गंभीर रूप से लुप्तप्राय सुमात्रा हाथी के संरक्षण और संरक्षण का समर्थन करने के लिए लियोनार्डो डिकैप्रियो फाउंडेशन को $30 लाख का दान दिया। 2020 में, उन्होंने गंभीर रूप से लुप्तप्राय सुमात्रा हाथी के संरक्षण और संरक्षण का समर्थन करने के लिए लियोनार्डो डिकैप्रियो फाउंडेशन को $30 लाख का दान दिया। 2021 में, उन्होंने गंभीर रूप से लुप्तप्राय सुमात्रा हाथी के संरक्षण और संरक्षण का समर्थन करने के लिए लियोनार्डो डिकैप्रियो फाउंडेशन को $43 मिलियन का दान दिया।<|im_end|>\\n', 'rejected': 'प्रसिद्ध अभिनेता और परोपकारी, लियोनार्डो डिकैप्रियो के साक्षात्कार और जीवनी खोजें, जो अपनी पर्यावरणीय सक्रियता और टाइटैनिक और द रेवनेंट जैसी फिल्मों में पुरस्कार विजेता प्रदर्शन के लिए जाने जाते हैं।<|im_end|>\\n'}\n"
     ]
    }
   ],
   "source": [
    "print(dataset[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c17bae2-a6a3-4f6c-b842-bc32847def79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LoRA configuration\n",
    "peft_config = LoraConfig(\n",
    "    r=16,\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0.05,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c8003e61-ae44-400f-8b5c-668c5499bf78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model to fine-tune\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float16,\n",
    "    load_in_4bit=True\n",
    ")\n",
    "model.config.use_cache = False\n",
    "\n",
    "# Reference model\n",
    "ref_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float16,\n",
    "    load_in_4bit=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af86df74-040d-4253-bc41-56c8b0d760f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:314: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d6384bdbfec4041a50bc55c3567eabf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6633 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (4632 > 2048). Running this sequence through the model will result in indexing errors\n",
      "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.2"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/workspace/wandb/run-20240128_065712-os8qoadz</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/ahm-rimer/hindi_dpo_test/runs/os8qoadz' target=\"_blank\">logical-monkey-4</a></strong> to <a href='https://wandb.ai/ahm-rimer/hindi_dpo_test' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/ahm-rimer/hindi_dpo_test' target=\"_blank\">https://wandb.ai/ahm-rimer/hindi_dpo_test</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/ahm-rimer/hindi_dpo_test/runs/os8qoadz' target=\"_blank\">https://wandb.ai/ahm-rimer/hindi_dpo_test/runs/os8qoadz</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:226: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n",
      "  warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')\n",
      "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='92' max='621' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 92/621 1:09:37 < 6:49:13, 0.02 it/s, Epoch 0.44/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.703000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.698300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.703000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.675100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.694200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.675100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.675300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.658600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.641100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.634400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.609300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.606500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.607200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.560900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.563000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.527200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.473300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.486900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.454200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.426300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.381000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.361900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.344100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.298400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.298100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.256600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.245600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.214800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.190200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.163100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.148100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.136700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.117100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.097800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.105300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.072300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.077200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.056100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.051000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.041900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.035800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.031700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.013200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.014600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.036300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.012200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.012500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.011200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.013500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.008400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.004900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.006900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.010800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.006800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.003900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.005600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.002100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.001800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>0.004600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.001600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>61</td>\n",
       "      <td>0.002000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>62</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63</td>\n",
       "      <td>0.001000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>0.002900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>65</td>\n",
       "      <td>0.000800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66</td>\n",
       "      <td>0.004300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>67</td>\n",
       "      <td>0.000700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>68</td>\n",
       "      <td>0.002700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>69</td>\n",
       "      <td>0.000500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.002600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71</td>\n",
       "      <td>0.000600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>0.000400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>73</td>\n",
       "      <td>0.000800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>74</td>\n",
       "      <td>0.000700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75</td>\n",
       "      <td>0.000400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>76</td>\n",
       "      <td>0.000600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>77</td>\n",
       "      <td>0.000900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>78</td>\n",
       "      <td>0.000300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>79</td>\n",
       "      <td>0.001000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.000300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>81</td>\n",
       "      <td>0.002800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>82</td>\n",
       "      <td>0.000900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>83</td>\n",
       "      <td>0.000300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>84</td>\n",
       "      <td>0.000200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>85</td>\n",
       "      <td>0.000300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>86</td>\n",
       "      <td>0.010300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>87</td>\n",
       "      <td>0.001800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>88</td>\n",
       "      <td>0.000400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>89</td>\n",
       "      <td>0.000300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.000200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training arguments\n",
    "training_args = TrainingArguments(\n",
    "    per_device_train_batch_size=,\n",
    "    gradient_accumulation_steps=4,\n",
    "    gradient_checkpointing=True,\n",
    "    learning_rate=5e-5,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    num_train_epochs=3,\n",
    "    save_strategy=\"no\",\n",
    "    logging_steps=1,\n",
    "    output_dir=new_model,\n",
    "    optim=\"paged_adamw_32bit\",\n",
    "    warmup_steps=100,\n",
    "    bf16=True,\n",
    "    report_to=\"wandb\",\n",
    ")\n",
    "\n",
    "# Create DPO trainer\n",
    "dpo_trainer = DPOTrainer(\n",
    "    model,\n",
    "    ref_model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    peft_config=peft_config,\n",
    "    beta=0.1,\n",
    "    max_prompt_length=1024,\n",
    "    max_length=1536,\n",
    ")\n",
    "\n",
    "# Fine-tune model with DPO\n",
    "dpo_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4383edf2-f7e9-4502-b375-8918a9712f80",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
